# -*- coding: utf-8 -*-
"""

GarpMaker is a class for randomly generating datasets which satisfy e-GARP.

Calling GarpMaker(N, L, e) will return an object for making random e-GARP
consistent data with N observations and L goods.

Let gm be a GarpMaker object. gm(M) returns a dataframe with M consumers.

"""

import numpy as np
import pandas as pd
from random import shuffle
from numpy.random import rand, standard_normal
from data_manager import mats_to_df


class GarpMaker():
    '''A class for randomly generating datasets which satisfy e-GARP.

    Calling GarpMaker(N, L, e) returns an object for making random e-GARP
    consistent data with N observations and L goods.

    mode = 'b' or 'r' chooses how random directions are drawn: 'b' draws a
    direction in budget-share space (mapped back to consumption space), 'r'
    draws a direction in ray space (bundles normalised to sum to one).

    If gm is an instance of GarpMaker then calling gm(M) creates a dataframe
    with M consumers. They all share the same budget sets (price matrix P).

    Attributes:
        N, L, e : number of observations, number of goods, efficiency level.
        P : (N, L) price matrix, one row of prices per observation.
        C : (N, L) current consumption matrix; row n costs 10 at prices P[n].
        V : (N, N) matrix with V[a, b] = C[a] @ P[b] (cost of bundle a at
            prices b). It must be kept in sync with C; the cycle test reads it.
        Clist : snapshots of C collected by the most recent call.
        mode : 'b' or 'r'.
    '''

    def __init__(self, N: int, L: int, e: float = 1., mode: str = 'b'):
        '''Generate the budget-set (price) matrix P and the initial bundles C.

        Parameters:
            N : number of observations (budget sets).
            L : number of goods.
            e : efficiency level used by the e-GARP cycle test.
            mode : 'b' (budget-share directions) or 'r' (ray directions).

        Raises:
            ValueError : if mode is not 'b' or 'r'.
        '''
        # Validate first so a bad mode fails before any work is done.
        if mode not in ('b', 'r'):
            raise ValueError('mode must be b or r')
        self.mode = mode
        self.N = N
        self.L = L
        self.e = e

        self.P = self._price_gen()  # Make the price matrix.

        # Initial C: equal budget shares 1/L and income 10 for every
        # observation, i.e. C[n, l] = (10 / L) / P[n, l].
        shares = np.ones(L) / L
        self.C = (10. * shares) / self.P

        self.Clist = []
        # V[a, b] = C[a] @ P[b]: values of the bundles at the different prices.
        # Precomputing it makes the cycle test cheap; it must be updated
        # whenever C changes.
        self.V = self.C @ self.P.T

    def __call__(self, M: int, burninN: int = 100, nextdataN: int = 10) -> pd.DataFrame:
        '''Generate M random e-GARP consistent consumers using the budgets self.P.

        Parameters:
            M : number of consumers to produce.
            burninN : number of _alter sweeps run before the first consumer.
            nextdataN : number of _alter sweeps run between consecutive consumers.

        Returns:
            The dataframe built by _output_data from the M snapshots of C.

        Each call starts a fresh Clist, so the result contains exactly the M
        consumers generated by this call. The random walk itself continues
        from the current C.
        '''
        self.Clist = []

        # Burn-in: let the random walk move away from the initial C.
        for _ in range(burninN):
            self._alter()

        # Record one snapshot of C after every nextdataN sweeps.
        for _ in range(M):
            for _ in range(nextdataN):
                self._alter()
            self.Clist.append(self.C.copy())

        return self._output_data()

    #=========================================================================
    #Alter
    #=========================================================================

    def _alter(self) -> None:
        '''Perform one sweep: perturb every observation once, in random order.

        For each observation n, draw a random direction d and a uniform weight
        u in [0, 1), then move u of the way toward the e-GARP boundary in
        direction d.'''
        ob_list = list(range(self.N))
        shuffle(ob_list)  # random.shuffle: randomise the visiting order.
        for n in ob_list:
            d = self._rand_dir(self.P[n], self.C[n])  # Random direction.
            u = rand()  # Random weight.
            self._alter_implement(n, d, u)

    def _alter_implement(self, n: int, d, u: float) -> None:
        '''Move observation n a fraction u of the way to the e-GARP boundary.

        Parameters:
            n : observation index.
            d : direction from _rand_dir (consumption space in mode 'b',
                ray space in mode 'r').
            u : weight between 0 and 1.

        Overwrites self.C[n] with the new bundle, rescaled to cost 10 at P[n],
        and updates the matching row self.V[n].

        Raises:
            ValueError : if self.mode is not 'b' or 'r'.
        '''
        C = self.C; P = self.P; V = self.V
        c_old = np.copy(C[n])  # Copy before C[n] is overwritten below.

        c_edge = self._rp_test_edge(n, d)

        if self.mode == 'b':
            # Interpolate in consumption space, then rescale to income 10.
            c_new = (1. - u) * c_old + u * c_edge
            c_new = 10. * c_new / (P[n] @ c_new)
        elif self.mode == 'r':
            # Interpolate the normalised rays, then rescale to income 10.
            r_old = c_old / c_old.sum()
            r_edge = c_edge / c_edge.sum()
            r_new = (1. - u) * r_old + u * r_edge
            c_new = 10. * r_new / (P[n] @ r_new)
        else:
            raise ValueError('Invalid mode')

        # Change observation n. Only row n of V depends on C[n].
        C[n] = c_new
        V[n] = P @ C[n]

    #=========================================================================
    #Alter Helpers
    #=========================================================================

    #============
    #edge finders
    #============
    def _rp_test_edge(self, n: int, d):
        '''Return the most extreme bundle in direction d that keeps
        observation n free of e-revealed-preference cycles.

        The candidates are (1-k) c_old + k c_edge for k in [0, 1]. Here
        c_old = C[n], and c_edge is the edge of the consumption space in
        direction d. A binary search on k, to a tolerance of 2**-8, finds the
        largest k that still passes the cycle test. This assumes that k = 0
        passes and that the passing values of k form an interval. The returned
        bundle is rescaled to cost 10 at prices P[n].

        C[n] and V[n] are used as scratch space during the search. Both are
        restored to their original values before returning.'''
        C = self.C; P = self.P; V = self.V; e = self.e

        c_old = np.copy(C[n])  # Current consumption bundle for ob n.
        v_old = np.copy(V[n])  # Saved so V[n] can be restored exactly.
        c_edge = self._budget_edge(c_old, d, P[n])  # Edge in direction d.

        # Binary search for the boundary weight k.
        k_high = 1.0; k_low = 0.0
        while k_high - k_low >= 2**(-8):
            k = (k_high + k_low) / 2.
            c = (1. - k) * c_old + k * c_edge
            C[n] = 10. * c / (P[n] @ c)  # Keep income at 10.
            V[n] = P @ C[n]
            if self._cycle(n, e):
                k_high = k  # Violates e-GARP: boundary is below k.
            else:
                k_low = k   # Satisfies e-GARP: boundary is at or above k.

        # Undo the scratch modifications; the last probe may have violated e-GARP.
        C[n] = c_old
        V[n] = v_old

        c = (1. - k_low) * c_old + k_low * c_edge  # Last passing bundle.
        return 10. * c / (P[n] @ c)  # Make sure income is 10.

    @staticmethod
    def _max_step(x, d):
        '''Return the largest k >= 0 such that x + k * d >= 0 componentwise.

        x must be non-negative. Only components where d < 0 can become
        binding. A zero component of x with d < 0 correctly gives k = 0.

        Raises:
            ValueError : if no component of d is negative (the step is unbounded).
        '''
        neg = d < 0
        if not neg.any():
            raise ValueError('direction d has no negative component; edge is unbounded')
        return np.min(-x[neg] / d[neg])

    def _budget_edge(self, c, d, p):
        '''Start at bundle c and move in direction d until the edge of the
        consumption space is hit. Return that point.

        c should be non-negative. In mode 'b', d is a direction in consumption
        space and the edge is c + k d. In mode 'r', d is a direction in ray
        space: c is first normalised to sum to one. The edge ray is then
        rescaled to cost 10 at prices p.

        Raises:
            ValueError : on an invalid mode, or if d has no negative component.
        '''
        if self.mode == 'b':
            edge = c + self._max_step(c, d) * d
        elif self.mode == 'r':
            r = c / c.sum()  # Pull c into the unit simplex.
            edge_r = r + self._max_step(r, d) * d
            edge = 10. * edge_r / (p @ edge_r)
        else:
            raise ValueError('Invalid mode.')

        return edge

    def _rand_dir(self, p, c):
        '''Return a random direction vector.

        A standard-normal draw is de-meaned so its components sum to zero. In
        mode 'b' it is divided by p, which gives a direction in the budget set
        with prices p (p @ v == 0). In mode 'r' it is returned as-is, as a
        direction in ray space. The argument c is currently unused.

        Raises:
            ValueError : if self.mode is not 'b' or 'r'.
        '''
        z = standard_normal(self.L)
        z = z - np.mean(z)  # Random direction in budget-share space.
        if self.mode == 'b':
            return z / p  # Random direction in the budget set with price p.
        if self.mode == 'r':
            return z
        raise ValueError('passed invalid mode.')

    #=========================================================================
    #Output Data
    #=========================================================================
    def _output_data(self):
        '''Turn the consumption snapshots in Clist into one dataframe.

        Each snapshot is converted with mats_to_df(C, P, consumer_index). The
        per-consumer frames are then concatenated. With an empty Clist
        (M == 0), pd.concat has nothing to concatenate and raises ValueError.
        '''
        P = self.P
        # One dataframe per consumer.
        df_list = [mats_to_df(C, P, m) for m, C in enumerate(self.Clist)]
        return pd.concat(df_list)

    #=========================================================================
    #Make Budgets
    #=========================================================================
    def _price_gen(self):
        '''Return an N x L price matrix.

        Assume income w = 10. For each good l, spending all income on l buys
        between 1 and 10 units. The affordable amount is drawn uniformly,
        independently across observations and goods.'''
        N, L = self.N, self.L
        # Affordable quantities: uniform [0, 1) rescaled to [1, 10).
        Pbase = 1. + 9. * rand(N, L)
        return 10. / Pbase

    #=========================================================================
    #Find Path
    #=========================================================================
    def _cycle(self, n: int, e: float = 1.0) -> bool:
        '''Return True if there is an e-revealed-preference cycle through obs n.

        Observation m is e-revealed preferred to i when
        e * V[m, m] > V[i, m], i.e. bundle i was strictly cheaper than e times
        the expenditure at m's prices. A depth-first search from n follows
        these relations and reports a cycle as soon as n is reached again.
        '''
        V = self.V
        N = V.shape[0]  # Number of obs.

        visited = np.zeros(N, dtype=bool)  # Obs already queued are skipped.
        stack = [n]  # Observations left to expand.
        while stack:
            m = stack.pop()
            # Every ob i that m is e-revealed preferred to.
            rp = e * V[m, m] > V[:, m]
            if rp[n]:
                return True  # Back to n: found a cycle.
            new = np.flatnonzero(rp & ~visited)
            visited[new] = True
            stack.extend(new.tolist())
        return False


